import h5py
from pytorch3d.io import save_ply, load_ply
import torch 
import torch
import numpy as np
from IPython import embed
from pytorch3d.loss import chamfer_distance
import glob

# Generated point clouds that will each be matched to their nearest original shape.
pc_file = '/home/tiangel/datasets/shapeglot_pc_v2_train.h5'
# Read the 'data' dataset fully into memory inside a context manager so the
# HDF5 file handle is closed afterwards (it was previously left open).
with h5py.File(pc_file, 'r') as h5:
    pcs = np.array(h5['data'])
# NOTE(review): presumably (num_clouds, num_points, 3); the matching loop below
# relies on pcs being rank-3 (batch, points, coords) — confirm against the writer.
pcs = torch.Tensor(pcs).cuda()

# Original (reference) shapes, one .ply file per shape.
ori_pc_glob = '/home/tiangel/datasets/train_pc_10000p_4k/*.ply'
# Sort so the index reported for each match (into ori_pcs / pcs_name) is
# reproducible across runs; raw glob order is filesystem-dependent.
ori_pcs_paths = sorted(glob.glob(ori_pc_glob))
if not ori_pcs_paths:
    # Fail loudly here instead of with an opaque torch.cat/stack error below.
    raise FileNotFoundError(f'no .ply files match {ori_pc_glob}')

# Shape name = file name up to its first '.', aligned index-for-index with ori_pcs.
pcs_name = [p.split('/')[-1].split('.')[0] for p in ori_pcs_paths]

# load_ply(p)[0] is the vertex tensor of each file; stacking requires every
# file to contain the same number of points.
ori_pcs = torch.stack([load_ply(p)[0] for p in ori_pcs_paths]).cuda()

# For every generated cloud in `pcs`, find the reference cloud in `ori_pcs`
# with the smallest Chamfer distance. After the loop, idx_list[i] (an index
# into ori_pcs / pcs_name) and cd_list[i] (the distance) describe pcs[i];
# both lists hold 0-dim CUDA tensors.
idx_list = []
cd_list = []
bs = 4  # generated clouds per step; each step compares bs * n_ref cloud pairs
n_ref = ori_pcs.shape[0]
# Every reference cloud repeated bs times: row i*n_ref + j is ori_pcs[j].
tile_gt = torch.tile(ori_pcs, (bs, 1, 1))

with torch.no_grad():
    for start in range(0, pcs.shape[0], bs):
        output = pcs[start:start + bs]
        # Smaller than bs on the final batch when pcs.shape[0] % bs != 0.
        n_out = output.shape[0]
        # Row i*n_ref + j is output[i], lining up with tile_gt row i*n_ref + j,
        # so every generated cloud is paired with every reference cloud.
        tile_our = output.repeat_interleave(n_ref, dim=0)
        # Slice tile_gt to the real batch size: comparing a short final batch
        # against the full bs-sized tile_gt makes chamfer_distance raise on
        # the batch-size mismatch.
        cd_dis = chamfer_distance(tile_our, tile_gt[:n_out * n_ref], batch_reduction=None)[0]
        cd_dis = cd_dis.reshape(n_out, n_ref)
        mincd, minidx = torch.min(cd_dis, 1)

        # Iterating a 1-D tensor yields 0-dim tensors, one per generated cloud.
        idx_list.extend(minidx)
        cd_list.extend(mincd)

# Open an interactive IPython shell to inspect pcs_name, idx_list, cd_list, etc.
embed()